{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Определение автора произведения по предложению из текста\n",
    "Один мой друг рассказывал, что в студенческие годы удивлял одногруппников тем, что мог почти безошибочно определить писателя по отрывку из его произведения. Попробуем что-то похожее сделать средствами машинного обучения. В качестве данных взят training set из уже завершённого соревнования Spooky Author Identification (https://www.kaggle.com/c/spooky-author-identification/data).\n",
    "\n",
    "Ценность задачи может состоять, например, в том, что похожими методами можно определить автора текста, писавшего под псевдонимом (если известен примерный круг \"подозреваемых\")."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline\n",
    "import re\n",
    "import string\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import log_loss, accuracy_score\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from scipy.sparse import csr_matrix, hstack\n",
    "from sklearn.naive_bayes import MultinomialNB\n",
    "from sklearn.model_selection import ShuffleSplit\n",
    "from sklearn.model_selection import learning_curve\n",
    "from sklearn.model_selection import validation_curve"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Первичный анализ\n",
    "Загрузим данные и посмотрим на них."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv('../../data/spooky_authors.csv')\n",
    "data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.unique(data['author'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Набор данных содержит три поля:\n",
    "* id - идентификатор строки (не представляет интереса)\n",
    "* text - предложение из текста произведения одного из трёх авторов\n",
    "* author - автор, а именно:\n",
    "    * EAP - Edgar Allan Poe\n",
    "    * HPL - H.P. Lovecraft\n",
    "    * MWS - Mary Wollstonecraft Shelley\n",
    "    \n",
    "Посмотрим, сбалансированы ли классы."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.groupby('author')['id'].count().plot(kind='bar')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Количество данных по разным авторам несколько отличается, но не слишком сильно.\n",
    "\n",
    "Вытащим из текста слова и посмотрим, отличаются ли у авторов количество слов в предложении, средняя длина слова, количество пунктуации и т.д."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "words = data['text'].apply(lambda t: [w.lower() for w in re.findall(r'\\b[^\\W\\d_]+\\b', t)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Сначала посмотрим на количество слов в предложении."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data['length'] = words.apply(lambda ws: len(ws))\n",
    "data.boxplot(column='length',by='author')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Пока видно только различие по аномально длинным предложениям. Выбросы достигают 800+ слов, но только единичные предложения выходят за 200 слов, а основная масса лежит где-то в пределах пятидесяти."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data['length'].describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Построим такой же график, но исключив крайние выбросы."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data[data['length'] < 200].boxplot(column='length',by='author')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Теперь отличия более заметны (например, короткие предложения чаще встречаются у Эдгара По, а у Лавкрафта самое узкое расределение, т.е. очень длинные и очень короткие предложения скорее всего не его).\n",
    "\n",
    "Посмотрим теперь на количество уникальных слов в предложении."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data['unique'] = words.apply(lambda ws: len(set(ws)) / len(ws))\n",
    "data.boxplot(column='unique',by='author')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "И вновь у Эдгара По больше крайних значений, а у Лавкрафта самое узкое распределение.\n",
    "\n",
    "Посмотрим на среднюю длину слова в предложении."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data['word_length'] = words.apply(lambda ws: np.mean([len(w) for w in ws]))\n",
    "data.boxplot(column='word_length',by='author')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "И вновь можно сказать, что самые низкие значения характерны для По, а у Лавкрафта практически нет ни очень высоких, ни очень низких значений.\n",
    "\n",
    "Наконец, посмотрим, какая часть символов является пунктуацией."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data['punctuation'] = data['text'].apply(lambda t: sum(map(lambda letter: letter in string.punctuation, t)) / len(t))\n",
    "data.boxplot(column='punctuation',by='author')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Из-за выбросов график трудно интерпретировать, посмотрим ближе значения, не превышающие 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data[data['punctuation'] < 0.2].boxplot(column='punctuation',by='author')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Теперь отличия видны лучше. Наибольшее количество пунктуации характерно для По, наименьшее - для Лавкрафта.\n",
    "\n",
    "Матрица корреляций этих признаков скорее всего будет не очень информативной, но раз она несколько раз упоминается в плане и критериях оценки, то построим её."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data[['length', 'unique', 'word_length', 'punctuation']].corr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.plotting.scatter_matrix(data[['length', 'unique', 'word_length', 'punctuation']], alpha = 0.3, figsize = (14,8), diagonal = 'kde');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Как и ожидалось, заметная корреляция только одна: отрицательная корреляция между длиной предложения и количеством уникальных слов в нём. Это логично: чем длиннее предложение, тем вероятнее в нём будут повторяться слова. Если бы значение корреляции было близко к -1, один из признаков стоило бы откинуть, но поскольку только -0.63, то пусть остаются оба."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Обучение модели\n",
    "\n",
    "Перейдём собственно к обучению. Первым делом надо ответить на следующие вопросы:\n",
    "* Что выбрать в качестве метрики качества.\n",
    "* Использовать ли TF-IDF или достаточно просто CountVectorizer.\n",
    "* Какой алгоритм использовать. В курсе для подобных задач мы всегда использовали линейные методы, но для задач классификации текстов классическим вариантом также является наивный байесовский классификатор.\n",
    "\n",
    "В качестве метрики качества можно было бы выбрать accuracy score или log-loss. Accuracy score для данной задачи была бы вполне применима (классы более-менее сбалансированы, и у нас нет предпочтений, в какую сторону менее болезненно было бы ошибаться), тем более, что у accuracy score очень хорошая интерпретируемость. Но log-loss является более тонкой метрикой, позволяющей оптимизировать не только правильность ответа, но и степень уверенности в нём.\n",
    "\n",
    "Т.о. в качестве метрики качества будет взят log-loss (но accuracy score будем тоже выводить, просто для информации).\n",
    "\n",
    "Что касается использования TF-IDF и выбора алгоритма ответ прост: мы попробуем разные варианты и посмотрим, что работает лучше. Линейные методы должны работать неплохо для разреженных данных большой размерности (которые у нас получатся после векторизации). Байсовские методы тоже должны неплохо справляться с такими задачами.\n",
    "\n",
    "Кроме того, мы сначала попробуем обучать только на самих текстах, а потом посмотрим, дадут ли прирост проанализированные фичи вроде длины предложения и количества пунктуации. Скорее всего, эти признаки смогут немного улучшить модель, т.к. такие вещи как длина предложения, длина слов, знаки препинания отражают стиль писателя наравне с используемыми словами (и как мы видели выше у разных писателей на самом деле эти параметры немного отличаются).\n",
    "\n",
    "Кросс-валидацию будем проводить сразу вместе с поиском оптимальных параметров через GridSearchCV. По умолчанию GridSearchCV использует StratifiedKFold (то есть все во фолды попадёт примерно равное количество представителей каждого класса) и количество фолдов, равное 3. Эти умолчания представляются разумными, так что оставим их.\n",
    "\n",
    "Для надёжности также выделим отдельно validation set чтобы проверять на нём промежуточные результаты и test set, на котором проверим итоговую модель в самом конце.\n",
    "\n",
    "Итак, разобьём данные на train, validation и test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = data.drop(['id', 'author'], axis=1)\n",
    "author_encoder = LabelEncoder()\n",
    "y = author_encoder.fit_transform(data['author'])\n",
    "\n",
    "X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size=0.2, random_state=17)\n",
    "X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.2, random_state=17)\n",
    "\n",
    "text_train = X_train['text']\n",
    "text_val = X_val['text']\n",
    "text_train_val = X_train_val['text']\n",
    "text_test = X_test['text']\n",
    "\n",
    "print(X.shape, y.shape)\n",
    "print(X_train_val.shape, X_test.shape, y_train_val.shape, y_test.shape)\n",
    "print(X_train.shape, X_val.shape, y_train.shape, y_val.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Напишем небольшую функцию для вывода результатов обучения:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_results(clf, x_val, y_val):\n",
    "    y_pred_val_proba = clf.predict_proba(x_val)\n",
    "    y_pred_val_labels = clf.predict(x_val)\n",
    "    print(\"Validation set log loss: {} (accuracy: {})\".format(log_loss(y_val, y_pred_val_proba), accuracy_score(y_val, y_pred_val_labels)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Подготовим векторайзеры и преобразуем данные."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_vectorizer = CountVectorizer()\n",
    "count_train = count_vectorizer.fit_transform(text_train)\n",
    "count_val = count_vectorizer.transform(text_val)\n",
    "\n",
    "tfidf_vectorizer = TfidfVectorizer()\n",
    "tfidf_train = tfidf_vectorizer.fit_transform(text_train)\n",
    "tfidf_val = tfidf_vectorizer.transform(text_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Логистическая регрессия\n",
    "\n",
    "Начнём с самого простого варианта: CountVectorizer и логистическая регрессия."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logit_count_clf = LogisticRegression().fit(count_train, y_train)\n",
    "print_results(logit_count_clf, count_val, y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Уже неплохо. Попробуем настроить регуляризацию."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logit_count_grid = GridSearchCV(LogisticRegression(random_state=17), {'C': [0.1, 0.3, 1, 3, 10, 30, 100, 300]}, scoring='neg_log_loss')\n",
    "logit_count_grid.fit(count_train, y_train)\n",
    "print(logit_count_grid.best_params_, ', cross validation score:', -logit_count_grid.best_score_)\n",
    "print_results(logit_count_grid, count_val, y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Значение по умолчанию оказалось оптимальным.\n",
    "\n",
    "Попробуем TF-IDF."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logit_tfidf_clf = LogisticRegression(random_state=17).fit(tfidf_train, y_train)\n",
    "print_results(logit_tfidf_clf, tfidf_val, y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Результат существенно хуже (видно, что accuracy изменилась несущественно, но уверены в предсказаниях мы намного меньше).\n",
    "\n",
    "Возможно, дело в регуляризации."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logit_tfidf_grid = GridSearchCV(LogisticRegression(random_state=17), {'C': [0.1, 0.3, 1, 3, 10, 30, 100, 300]}, scoring='neg_log_loss')\n",
    "logit_tfidf_grid.fit(tfidf_train, y_train)\n",
    "print(logit_tfidf_grid.best_params_, ', cross validation score:', -logit_tfidf_grid.best_score_)\n",
    "print_results(logit_tfidf_grid, tfidf_val, y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "После настройки регуляризации результат намного лучше (и чуть лучше, чем без TF-IDF).\n",
    "\n",
    "Проверим, изменится ли результат, если применить логарифмическое преобразование к TF."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tfidf_log_vectorizer = TfidfVectorizer(sublinear_tf=True)\n",
    "tfidf_log_train = tfidf_log_vectorizer.fit_transform(text_train)\n",
    "tfidf_log_val = tfidf_log_vectorizer.transform(text_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logit_tfidf_log_grid = GridSearchCV(LogisticRegression(random_state=17), {'C': [0.1, 0.3, 1, 3, 10, 30, 100, 300]}, scoring='neg_log_loss')\n",
    "logit_tfidf_log_grid.fit(tfidf_log_train, y_train)\n",
    "print(logit_tfidf_log_grid.best_params_, ', cross validation score:', -logit_tfidf_log_grid.best_score_)\n",
    "print_results(logit_tfidf_log_grid, tfidf_log_val, y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Неа.\n",
    "\n",
    "Попробуем, не будет ли результат лучше, если использовать биграммы."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tfidf_12_vectorizer = TfidfVectorizer(ngram_range=(1, 2))\n",
    "tfidf_12_train = tfidf_12_vectorizer.fit_transform(text_train)\n",
    "tfidf_12_val = tfidf_12_vectorizer.transform(text_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logit_tfidf_12_grid = GridSearchCV(LogisticRegression(random_state=17), {'C': [0.1, 0.3, 1, 3, 10, 30, 100, 300]}, scoring='neg_log_loss')\n",
    "logit_tfidf_12_grid.fit(tfidf_12_train, y_train)\n",
    "print(logit_tfidf_12_grid.best_params_, ', cross validation score:', -logit_tfidf_12_grid.best_score_)\n",
    "print_results(logit_tfidf_12_grid, tfidf_12_val, y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Немного лучше.\n",
    "А если триграммы?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tfidf_13_vectorizer = TfidfVectorizer(ngram_range=(1, 3))\n",
    "tfidf_13_train = tfidf_13_vectorizer.fit_transform(text_train)\n",
    "tfidf_13_val = tfidf_13_vectorizer.transform(text_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logit_tfidf_13_grid = GridSearchCV(LogisticRegression(random_state=17), {'C': [1, 3, 10, 30, 100, 300]}, scoring='neg_log_loss')\n",
    "logit_tfidf_13_grid.fit(tfidf_13_train, y_train)\n",
    "print(logit_tfidf_13_grid.best_params_, ', cross validation score:', -logit_tfidf_13_grid.best_score_)\n",
    "print_results(logit_tfidf_13_grid, tfidf_13_val, y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Оставим биграммы. И попробуем добавить наши дополнительные фичи."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tfidf_12_ext_train = hstack([tfidf_12_train, X_train[['length', 'unique', 'word_length', 'punctuation']]]).tocsr()\n",
    "tfidf_12_ext_val = hstack([tfidf_12_val, X_val[['length', 'unique', 'word_length', 'punctuation']]]).tocsr()\n",
    "\n",
    "logit_tfidf_12_ext_grid = GridSearchCV(LogisticRegression(), {'C': [1, 3, 10, 30, 100, 300]}, scoring='neg_log_loss')\n",
    "logit_tfidf_12_ext_grid.fit(tfidf_12_ext_train, y_train)\n",
    "print(logit_tfidf_12_ext_grid.best_params_, -logit_tfidf_12_ext_grid.best_score_)\n",
    "print_results(logit_tfidf_12_ext_grid, tfidf_12_ext_val, y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Прирост не то чтобы очень большой, но есть."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Кривые валидации и обучения\n",
    "Построим кривую обучения (log loss в зависимости от размера обучающей выборки)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None, n_jobs=1, train_sizes=np.linspace(.1, 1.0, 5)):\n",
    "    plt.figure()\n",
    "    plt.title(title)\n",
    "    if ylim is not None:\n",
    "        plt.ylim(*ylim)\n",
    "    plt.xlabel(\"Training examples\")\n",
    "    plt.ylabel(\"Score\")\n",
    "    \n",
    "    train_sizes, train_scores, test_scores = learning_curve(\n",
    "        estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes, scoring='neg_log_loss')\n",
    "    train_scores_mean = -np.mean(train_scores, axis=1)\n",
    "    train_scores_std = np.std(train_scores, axis=1)\n",
    "    test_scores_mean = -np.mean(test_scores, axis=1)\n",
    "    test_scores_std = np.std(test_scores, axis=1)\n",
    "    plt.grid()\n",
    "\n",
    "    plt.fill_between(train_sizes, train_scores_mean - train_scores_std,\n",
    "                     train_scores_mean + train_scores_std, alpha=0.1,\n",
    "                     color=\"r\")\n",
    "    plt.fill_between(train_sizes, test_scores_mean - test_scores_std,\n",
    "                     test_scores_mean + test_scores_std, alpha=0.1, color=\"g\")\n",
    "    plt.plot(train_sizes, train_scores_mean, 'o-', color=\"r\",\n",
    "             label=\"Training score\")\n",
    "    plt.plot(train_sizes, test_scores_mean, 'o-', color=\"g\",\n",
    "             label=\"Cross-validation score\")\n",
    "\n",
    "    plt.legend(loc=\"best\")\n",
    "    return plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tfidf_12_vectorizer_train_val = tfidf_12_vectorizer = TfidfVectorizer(ngram_range=(1, 2))\n",
    "tfidf_12_train_val = tfidf_12_vectorizer_train_val.fit_transform(text_train_val)\n",
    "tfidf_12_ext_train_val = hstack([tfidf_12_train_val, X_train_val[['length', 'unique', 'word_length', 'punctuation']]]).tocsr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "title = \"Learning Curves (Logistic Regression)\"\n",
    "cv = ShuffleSplit(n_splits=4, test_size=0.2, random_state=0)\n",
    "plot_learning_curve(LogisticRegression(C=100), title, tfidf_12_ext_train_val, y_train_val, ylim=(0, 0.8), cv=cv, n_jobs=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Кривая, соответствующая кросс-валидации, спускается вниз, но находится очень далеко от кривой, соответствующей обучению. Это говорит о том, что у нас большое недообучение и может помочь больше данных.\n",
    "\n",
    "Построим кривую валидации (для разных параметров C)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "param_range = np.logspace(-1, 3, 5)\n",
    "train_scores, test_scores = validation_curve(\n",
    "    LogisticRegression(C=100), tfidf_12_ext_train_val, y_train_val, param_name=\"C\", param_range=param_range,\n",
    "    cv=4, scoring=\"neg_log_loss\", n_jobs=4)\n",
    "train_scores_mean = -np.mean(train_scores, axis=1)\n",
    "train_scores_std = np.std(train_scores, axis=1)\n",
    "test_scores_mean = -np.mean(test_scores, axis=1)\n",
    "test_scores_std = np.std(test_scores, axis=1)\n",
    "\n",
    "plt.title(\"Validation Curve with Logistic Regression\")\n",
    "plt.xlabel(\"C\")\n",
    "plt.ylabel(\"Score\")\n",
    "plt.ylim(0.0, 1.1)\n",
    "lw = 2\n",
    "plt.semilogx(param_range, train_scores_mean, label=\"Training score\",\n",
    "             color=\"darkorange\", lw=lw)\n",
    "plt.fill_between(param_range, train_scores_mean - train_scores_std,\n",
    "                 train_scores_mean + train_scores_std, alpha=0.2,\n",
    "                 color=\"darkorange\", lw=lw)\n",
    "plt.semilogx(param_range, test_scores_mean, label=\"Cross-validation score\",\n",
    "             color=\"navy\", lw=lw)\n",
    "plt.fill_between(param_range, test_scores_mean - test_scores_std,\n",
    "                 test_scores_mean + test_scores_std, alpha=0.2,\n",
    "                 color=\"navy\", lw=lw)\n",
    "plt.legend(loc=\"best\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Кривая валидации спускается вниз, но после C=100 снова поднимается наверх и начинается переобучение. Т.о. выбранное через GridSearch значение оптимально."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Байесовские методы\n",
    "Попробуем теперь байесовские методы.\n",
    "\n",
    "Как и в случае логистической регрессии начнём с самого простого варианта: CountVectorizer + MultinomialNB без настройки параметров."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_count_clf = MultinomialNB().fit(count_train, y_train)\n",
    "print_results(logit_count_clf, count_val, y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Результат очень похож на результаты логистической регрессии. Попробуем настроить параметры alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_count_grid = GridSearchCV(MultinomialNB(), {'alpha': [0.01, 0.01, 0.1, 1]}, scoring='neg_log_loss')\n",
    "nb_count_grid.fit(count_train, y_train)\n",
    "print(nb_count_grid.best_params_, ', cross validation score:', -nb_count_grid.best_score_)\n",
    "print_results(nb_count_grid, count_val, y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Попробуем TF-IDF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_tfidf_grid = GridSearchCV(MultinomialNB(), {'alpha': [0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1]}, scoring='neg_log_loss')\n",
    "nb_tfidf_grid.fit(tfidf_train, y_train)\n",
    "print(nb_tfidf_grid.best_params_, ', cross validation score:', -nb_tfidf_grid.best_score_)\n",
    "print_results(nb_tfidf_grid, tfidf_val, y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Результат лучше. Попробуем биграммы."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_tfidf_12_grid = GridSearchCV(MultinomialNB(), {'alpha': [0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1]}, scoring='neg_log_loss')\n",
    "nb_tfidf_12_grid.fit(tfidf_12_train, y_train)\n",
    "print(nb_tfidf_12_grid.best_params_, ', cross validation score:', -nb_tfidf_12_grid.best_score_)\n",
    "print_results(nb_tfidf_12_grid, tfidf_12_val, y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "И триграммы."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_tfidf_13_grid = GridSearchCV(MultinomialNB(), {'alpha': [0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1]}, scoring='neg_log_loss')\n",
    "nb_tfidf_13_grid.fit(tfidf_13_train, y_train)\n",
    "print(nb_tfidf_13_grid.best_params_, ', cross validation score:', -nb_tfidf_13_grid.best_score_)\n",
    "print_results(nb_tfidf_13_grid, tfidf_13_val, y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Как и в случае логистической регрессии, триграммы работают хуже, оставим биграммы.\n",
    "\n",
    "Добавим наши фичи."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_tfidf_12_ext_grid = GridSearchCV(MultinomialNB(), {'alpha': [0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1]}, scoring='neg_log_loss')\n",
    "nb_tfidf_12_ext_grid.fit(tfidf_12_ext_train, y_train)\n",
    "print(nb_tfidf_12_ext_grid.best_params_, -nb_tfidf_12_ext_grid.best_score_)\n",
    "print_results(nb_tfidf_12_ext_grid, tfidf_12_ext_val, y_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Log-loss чуть лучше, почти достиг 0.41."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Таким образом, лучше всего сработала байесовская модель, обученная на TF-IDF с биграммами и с добавлением новых признаков."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Предсказания на тестовых данных\n",
    "\n",
    "Построим теперь предсказания на тестовых данных.\n",
    "\n",
    "Обучим самую удачную модель на train + validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tfidf_12_test = tfidf_12_vectorizer.transform(text_test)\n",
    "tfidf_12_ext_test = hstack([tfidf_12_test, X_test[['length', 'unique', 'word_length', 'punctuation']]]).tocsr()\n",
    "\n",
    "nb = MultinomialNB(alpha=0.01)\n",
    "nb.fit(tfidf_12_ext_train_val, y_train_val)\n",
    "print_results(nb, tfidf_12_ext_test, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Т.о. финальный результат на тестовой выборке: 0.376486. Результат на тествой выборке получился даже лучше, чем на валидации (0.410010, возможно, из-за того, что обучали на выборке большего размера). Тестовая выборка была получена с рандомизацией при помощи train_test_split."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Выводы\n",
    "\n",
    "Итак, обучение прошло успешно, переобучения не произошло.\n",
    "\n",
    "Как упоминалось в начале, данная модель может использоваться для определения авторов книг и статей, желавших остаться анонимными. Также, в более широком смысле, такие методы могут использоваться для любой классификации текста (например, для классификации заданных через сайт вопросов по темам).\n",
    "\n",
    "В качестве дальнейших улучшений мне видится:\n",
    "* Объединить предсказания нескольких успешных моделей (например, логистической регрессии, байеса и чего-нибудь ещё).\n",
    "* Можно попробовать рекуррентные нейронные сети. Сейчас мы учитываем только биграммы, то есть контекст теряется. Триграммы \"в лоб\" дают не очень хороший результат. Может быть, RNN смогла бы как-то более тонко учесть контекст и повысить качество предсказаний. С другой стороны, выборка небольшая, может и не хватить данных для обучения нейронной сети.\n",
    "* Можно попробовать как-то корректировать предсказания (clipping). Но это вряд ли даст сильные изменения в реальной задаче, скорее просто может немного улучшить log-loss."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
